package cn.dagteam.springboot.mongodb.starter.repository.support;

import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Optional;

import org.apache.commons.lang3.StringUtils;
import org.springframework.data.domain.Page;
import org.springframework.data.domain.PageRequest;
import org.springframework.data.domain.Pageable;
import org.springframework.data.domain.Sort;
import org.springframework.data.mongodb.core.MongoOperations;
import org.springframework.data.mongodb.core.MongoTemplate;
import org.springframework.data.mongodb.core.mapping.DBRef;
import org.springframework.data.mongodb.core.query.Query;
import org.springframework.data.mongodb.repository.query.MongoEntityInformation;
import org.springframework.data.mongodb.repository.support.SimpleMongoRepository;
import org.springframework.data.repository.support.PageableExecutionUtils;
import org.springframework.util.Assert;
import org.springframework.util.ReflectionUtils;

import cn.dagteam.springboot.mongodb.starter.Constants;
import cn.dagteam.springboot.mongodb.starter.entity.IdEntity;
import cn.dagteam.springboot.mongodb.starter.entity.LogicDeleteEntity;
import cn.dagteam.springboot.mongodb.starter.repository.MongoDynamicSearchRepository;
import cn.dagteam.springboot.mongodb.starter.tools.Converts;

/**
 * {@link SimpleMongoRepository} extension that adds filter-based dynamic queries ({@link QueryFilter}
 * lists turned into a {@link Query} by {@link QueryRequestBuilder}), one-level cascade filtering
 * through {@link DBRef} fields, and logical deletion for entities implementing {@link LogicDeleteEntity}.
 *
 * @param <T>  the entity type
 * @param <ID> the identifier type
 */
public class SimpleMongoDynamicSearchRepository<T, ID> extends SimpleMongoRepository<T, ID> implements MongoDynamicSearchRepository<T, ID> {

	/** Template used for every dynamic query issued by this repository. */
	private final MongoTemplate mongoTemplate;

	private final MongoEntityInformation<T, ID> entityInformation;

	/**
	 * Creates the repository.
	 *
	 * @param metadata        entity metadata (java type, collection name, id accessor)
	 * @param mongoOperations operations handed to the parent repository; must be a {@link MongoTemplate}
	 *                        because it is cast to one (a {@link ClassCastException} is thrown otherwise)
	 */
	public SimpleMongoDynamicSearchRepository(MongoEntityInformation<T, ID> metadata, MongoOperations mongoOperations) {
		super(metadata, mongoOperations);
		this.entityInformation = metadata;
		// Dynamic queries go through the MongoTemplate directly instead of the parent's operations.
		this.mongoTemplate = (MongoTemplate) mongoOperations;
	}

	@Override
	public Class<T> getEntityClass() {
		return entityInformation.getJavaType();
	}

	/**
	 * Finds a single entity whose {@code fieldName} equals {@code value}.
	 */
	@Override
	public Optional<T> findOne(String fieldName, Object value) {
		return findOne(QueryFilter.builder().eq(fieldName, value).build());
	}

	/**
	 * Finds a single entity matching all filters. Cascade filters are resolved and, like the other
	 * finders, logically deleted entities are excluded.
	 *
	 * @param filters the query filters (not modified; {@code null} is treated as "no filters")
	 * @return the matching entity, or empty when none matches
	 */
	@Override
	public Optional<T> findOne(List<QueryFilter> filters) {
		Query query = buildQuery(prepareFilters(filters));
		return Optional.ofNullable(mongoTemplate.findOne(query, entityInformation.getJavaType(), entityInformation.getCollectionName()));
	}

	@Override
	public List<T> findList(String fieldName, Object value) {
		return findList(fieldName, value, Sort.unsorted());
	}

	@Override
	public List<T> findList(String fieldName, Object value, Sort sort) {
		return findList(QueryFilter.builder().eq(fieldName, value).build(), sort);
	}

	@Override
	public List<T> findList(List<QueryFilter> filters) {
		return findList(filters, Sort.unsorted());
	}

	/**
	 * Finds all entities matching the filters, ordered by {@code sort}. Logically deleted entities
	 * are excluded.
	 *
	 * @param filters the query filters (not modified; {@code null} is treated as "no filters")
	 * @param sort    the ordering to apply
	 * @return the matching entities
	 */
	@Override
	public List<T> findList(List<QueryFilter> filters, Sort sort) {
		List<QueryFilter> prepared = prepareFilters(filters);
		Query query = QueryRequestBuilder.builder().entityClass(entityInformation.getJavaType()).filters(prepared).sort(sort).build();
		return mongoTemplate.find(query, entityInformation.getJavaType(), entityInformation.getCollectionName());
	}

	/**
	 * Convenience overload of {@link #findPage(List, Pageable)}.
	 *
	 * @param pageNo zero-based page number (as required by {@link PageRequest#of(int, int, Sort)})
	 */
	@Override
	public Page<T> findPage(List<QueryFilter> filters, int pageNo, int pageSize, Sort sort) {
		return findPage(filters, PageRequest.of(pageNo, pageSize, sort));
	}

	/**
	 * Finds one page of entities matching the filters. Logically deleted entities are excluded.
	 *
	 * @param filters  the query filters (not modified; {@code null} is treated as "no filters")
	 * @param pageable page, size and sort to apply
	 * @return the requested page together with the total number of matches
	 */
	@Override
	public Page<T> findPage(List<QueryFilter> filters, Pageable pageable) {
		List<QueryFilter> prepared = prepareFilters(filters);
		Query pageQuery = buildQuery(prepared).with(pageable);
		List<T> content = mongoTemplate.find(pageQuery, entityInformation.getJavaType(), entityInformation.getCollectionName());
		// Count with an unpaged query: MongoTemplate#count applies a query's skip/limit, so reusing
		// pageQuery would cap the total at the page size.
		Query countQuery = buildQuery(prepared);
		return PageableExecutionUtils.getPage(content, pageable,
				() -> mongoTemplate.count(countQuery, entityInformation.getJavaType(), entityInformation.getCollectionName()));
	}

	/**
	 * Counts the entities matching the filters, excluding logically deleted ones.
	 */
	@Override
	public int count(List<QueryFilter> filters) {
		Query query = buildQuery(prepareFilters(filters));
		return Converts.toInt(mongoTemplate.count(query, entityInformation.getJavaType(), entityInformation.getCollectionName()));
	}

	/**
	 * Tells whether any entity matches the filters, ignoring logically deleted ones.
	 */
	@Override
	public boolean exists(List<QueryFilter> filters) {
		Query query = buildQuery(prepareFilters(filters));
		return mongoTemplate.exists(query, entityInformation.getJavaType(), entityInformation.getCollectionName());
	}

	/**
	 * Loads the entity and deletes it through {@link #delete(Object)}, so logical deletion applies.
	 *
	 * @throws RuntimeException if no entity exists for {@code id}
	 */
	@Override
	public void deleteById(ID id) {
		T entity = findById(id).orElseThrow(() -> new RuntimeException("No entity found with id: " + id));
		delete(entity);
	}

	/**
	 * Deletes an entity. {@link LogicDeleteEntity} instances are only flagged as deleted and saved;
	 * any other entity is physically removed by id.
	 */
	@Override
	public void delete(T entity) {
		if (entity instanceof LogicDeleteEntity) {
			// Logical delete: keep the document and mark it as deleted.
			((LogicDeleteEntity) entity).setDelFlag(Boolean.TRUE);
			super.save(entity);
		} else {
			super.deleteById(entityInformation.getRequiredId(entity));
		}
	}

	/**
	 * Deletes each given entity through {@link #delete(Object)}.
	 */
	@Override
	public void deleteAll(Iterable<? extends T> entities) {
		Assert.notNull(entities, "The given Iterable of entities not be null!");

		entities.forEach(this::delete);
	}

	/**
	 * Deletes every entity returned by {@link #findAll()}, one at a time through
	 * {@link #delete(Object)}.
	 */
	@Override
	public void deleteAll() {
		List<T> entities = findAll();
		entities.forEach(this::delete);
	}

	/**
	 * Resolves cascade filters and appends the logical-delete filter when applicable.
	 *
	 * @param filters caller-supplied filters (never modified)
	 * @return a new, mutable list of effective filters
	 */
	private List<QueryFilter> prepareFilters(List<QueryFilter> filters) {
		List<QueryFilter> prepared = cascadeFilter(filters);
		addDeletedFilter(prepared);
		return prepared;
	}

	/**
	 * Builds an unsorted, unpaged query for this repository's entity type.
	 */
	private Query buildQuery(List<QueryFilter> filters) {
		return QueryRequestBuilder.builder().entityClass(entityInformation.getJavaType()).filters(filters).build();
	}

	/**
	 * Appends a {@code delFlag == false} filter when the entity type supports logical deletion.
	 *
	 * @param filters a mutable filter list to extend
	 */
	private void addDeletedFilter(List<QueryFilter> filters) {
		if (LogicDeleteEntity.class.isAssignableFrom(entityInformation.getJavaType())) {
			filters.add(new QueryFilter("delFlag", Operator.EQ, Boolean.FALSE));
		}
	}

	/**
	 * Resolves cascade filters. Only one level is supported.
	 * <p>
	 * A filter named {@code ref.prop}, where {@code ref} is a field annotated with {@link DBRef} and
	 * {@code prop} is not {@code id}, is removed. All such filters for the same {@code ref} are run as
	 * one query against the referenced type. An {@code ref IN (matched entities)} filter is added in
	 * their place. Every other filter is kept unchanged.
	 *
	 * @param filters caller-supplied filters (never modified; {@code null} is treated as empty)
	 * @return a new, mutable list of resolved filters
	 */
	private List<QueryFilter> cascadeFilter(List<QueryFilter> filters) {
		List<QueryFilter> result = new ArrayList<>();
		if (filters == null) {
			return result;
		}
		Class<T> javaType = entityInformation.getJavaType();
		Map<String, List<QueryFilter>> refFilters = new HashMap<>();
		for (QueryFilter filter : filters) {
			// A field name containing '.' usually targets a property of a referenced document.
			if (!StringUtils.contains(filter.getFieldName(), ".")) {
				result.add(filter);
				continue;
			}
			String mappingEntityName = StringUtils.substringBefore(filter.getFieldName(), ".");
			String mappingField = StringUtils.substringAfter(filter.getFieldName(), ".");
			Field f = ReflectionUtils.findField(javaType, mappingEntityName);
			// Only @DBRef fields are cascaded. Unknown fields (findField returns null), non-DBRef fields
			// and "ref.id" lookups stay as plain filters.
			if (f == null || !f.isAnnotationPresent(DBRef.class) || "id".equals(mappingField)) {
				result.add(filter);
				continue;
			}
			// OR conditions are only supported across fields of the same referenced object. Strip the
			// repeated "<ref>." prefixes so the whole expression targets the referenced entity.
			if (StringUtils.contains(mappingField, Constants.FILTER_OR_OPERATOR)) {
				mappingField = StringUtils.remove(mappingField, mappingEntityName + ".");
			}
			refFilters.computeIfAbsent(mappingEntityName, key -> new ArrayList<>())
					.add(new QueryFilter(mappingField, filter.getOp(), filter.getValue()));
		}
		for (Entry<String, List<QueryFilter>> entry : refFilters.entrySet()) {
			Field f = ReflectionUtils.findField(javaType, entry.getKey());
			Query query = QueryRequestBuilder.builder().entityClass(f.getType()).filters(entry.getValue()).build();
			@SuppressWarnings("unchecked")
			List<? extends IdEntity> list = (List<? extends IdEntity>) mongoTemplate.find(query, f.getType());
			// The IN filter must be given the referenced entity objects; passing only their ids does not match.
			result.add(new QueryFilter(entry.getKey(), Operator.IN, list));
		}
		return result;
	}
}
